import torch
from torch import nn


# 前馈神经网络
# 前馈神经网络主要针对多头注意力机制的结果进行非线性调整（精细化参数权重）
class PointWiseFFN(nn.Module):
    def __init__(self, ffn_input_size, ffn_hidden_size, ffn_output_size):
        super().__init__()

        self.ln1 = nn.Linear(ffn_input_size, ffn_hidden_size)
        self.relu = nn.ReLU()
        self.ln2 = nn.Linear(ffn_hidden_size, ffn_output_size)

    def forward(self, x):
        return self.ln2(self.relu(self.ln1(x)))
